from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as viz_utils
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import tensorflow as tf
import pathlib
import numpy as np
from PIL import Image
from tqdm import tqdm
import json


def _load_processed_files(result_file):
    """Return the set of image paths already recorded in ``result_file``.

    Each non-blank line of the file is parsed as a JSON object and its
    ``"file"`` entry (the full image path written by ``predict``) is
    collected.  An empty set is returned when the file does not exist.
    """
    processed = set()
    if not os.path.isfile(result_file):
        return processed

    print("Found predictions from a previous run.")
    with open(result_file, "r") as results:
        for line in tqdm(results):
            # Tolerate stray blank lines instead of failing in json.loads.
            if not line.strip():
                continue
            processed.add(json.loads(line)["file"])
    print(f"removing {len(processed)} files from dataset")
    return processed


def _list_pending_images(dataset_path, processed):
    """Walk ``dataset_path`` and return the full paths not in ``processed``."""
    pending = []
    for root, _dir_names, filenames in tqdm(os.walk(dataset_path)):
        for filename in filenames:
            full_path = os.path.join(root, filename)
            # The results file stores full paths, so the comparison must use
            # the joined path; comparing the bare file name never matches and
            # would cause every image to be processed again on each rerun.
            if full_path not in processed:
                pending.append(full_path)
    return pending


def main():
    """Run the Open Images detector over every file under the dataset path.

    Files already listed in ``open_images_results.txt`` inside the output
    directory are skipped, so an interrupted run can be resumed.  The
    remaining files are read, decoded as 3-channel JPEGs, resized to
    ``img_size`` x ``img_size``, batched and passed to ``predict``.
    """
    dataset_path = "/scratch/shared/beegfs/yuki/adiwol/joined_sym/all/"
    out_dir = "/scratch/shared/beegfs/chrisr/adiwol"
    img_size = 600
    batch_size = 8
    # Alternative configuration kept for reference:
    #dataset_path = "/scratch/local/ssd/datasets/imagenet-r256/train/"  # train:0 , val:2
    #out_dir = "./imagenet/train"
    os.makedirs(out_dir, exist_ok=True)

    print("creating dataset...")
    result_file = os.path.join(out_dir, "open_images_results.txt")
    files_to_remove = _load_processed_files(result_file)
    image_paths = _list_pending_images(dataset_path, files_to_remove)

    def process_path(file_path):
        """Load one image file and return ``(resized_image, file_path)``."""
        img = tf.io.read_file(file_path)
        img = tf.image.decode_jpeg(img, channels=3)
        img = tf.image.resize(img, [img_size, img_size])
        return img, file_path

    print("creating tf dataset")
    image_paths = tf.convert_to_tensor(image_paths, dtype=tf.string)
    dataset = tf.data.Dataset.from_tensor_slices(image_paths)
    print(f"dataset has {len(dataset)} images")
    dataset = dataset.map(process_path, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    print("done")

    label_map = load_labels("oid_v4_label_map.pbtxt")
    print("loading model")
    model = load_model("faster_rcnn_inception_resnet_v2_atrous_oid_v4_2018_12_12")
    detect_fn = model.signatures['serving_default']

    print("running model")
    predict(dataset, detect_fn, label_map, out_dir)


def load_model(model_name):
    """Fetch a detection model archive and load the SavedModel inside it.

    Args:
        model_name: Base name of the archive; ``.tar.gz`` is appended to form
            the file requested from the TensorFlow model download URL.

    Returns:
        The object returned by ``tf.saved_model.load`` for the archive's
        ``saved_model`` subdirectory.
    """
    archive_url = (
        'http://download.tensorflow.org/models/object_detection/'
        + model_name
        + '.tar.gz'
    )
    # get_file is asked to untar the archive; the returned path is treated as
    # the extracted directory.
    extracted_dir = tf.keras.utils.get_file(
        fname=model_name,
        origin=archive_url,
        untar=True)
    saved_model_dir = pathlib.Path(extracted_dir) / "saved_model"
    return tf.saved_model.load(str(saved_model_dir))


def load_labels(filename):
    """Fetch a label map file and build a category index from it.

    Args:
        filename: Name of the label map file inside the ``data`` directory of
            the TensorFlow object detection research repository.

    Returns:
        The category index produced by
        ``label_map_util.create_category_index_from_labelmap`` with display
        names enabled.
    """
    label_map_url = (
        'https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/'
        + filename
    )
    local_copy = tf.keras.utils.get_file(
        fname=filename,
        origin=label_map_url,
        untar=False)
    local_copy = pathlib.Path(local_copy)
    return label_map_util.create_category_index_from_labelmap(
        str(local_copy), use_display_name=True)


def predict(dataset, detect_fn, label_map, out_dir, save_samples=False):
    """Run the detector over ``dataset`` and append one JSON line per image.

    The label map is dumped as JSON to ``open_images_labels.txt`` (overwritten
    on every call) and detections are appended to ``open_images_results.txt``,
    both inside ``out_dir``.

    Args:
        dataset: Iterable yielding ``(image_batch, path_batch)`` pairs.
        detect_fn: Callable applied to the uint8-cast image batch; its result
            must provide ``detection_classes``, ``detection_boxes``,
            ``detection_scores`` and ``num_detections``.
        label_map: Category index, written to disk and used for drawing.
        out_dir: Directory receiving the label and result files.
        save_samples: When true, also draw the detections on every image of
            each batch and save them as ``test{i}.jpg`` in the working
            directory.
    """
    labels_path = os.path.join(out_dir, "open_images_labels.txt")
    with open(labels_path, "w") as labels_out:
        json.dump(label_map, labels_out)

    results_path = os.path.join(out_dir, "open_images_results.txt")
    with open(results_path, "a") as results_out:
        for images, paths in tqdm(dataset):
            images = tf.dtypes.cast(images, tf.uint8)
            outputs = detect_fn(images)

            paths = paths.numpy()
            class_ids = outputs["detection_classes"].numpy().astype(np.int64)
            box_coords = outputs['detection_boxes'].numpy()
            confidences = outputs['detection_scores'].numpy()
            counts = outputs['num_detections'].numpy().astype(np.int64)
            batch_len = images.shape[0]

            for idx in range(batch_len):
                # Only the first `keep` entries of each output are written.
                keep = counts[idx]
                record = {
                    "file": paths[idx].decode('utf-8'),
                    "classes": class_ids[idx][:keep].tolist(),
                    "boxes": box_coords[idx][:keep, :].tolist(),
                    "scores": confidences[idx][:keep].tolist(),
                }
                results_out.write(json.dumps(record) + "\n")

            if not save_samples:
                continue

            annotated = images.numpy()
            for idx in range(batch_len):
                viz_utils.visualize_boxes_and_labels_on_image_array(
                    annotated[idx],
                    box_coords[idx],
                    class_ids[idx],
                    confidences[idx],
                    label_map,
                    use_normalized_coordinates=True,
                    max_boxes_to_draw=200,
                    min_score_thresh=.30,
                    agnostic_mode=False)

                # NOTE(review): the name depends only on the within-batch
                # index, so every batch overwrites the previous batch's
                # samples.
                Image.fromarray(annotated[idx]).save(f"test{idx}.jpg")

    print('Done')

# Run the detection pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
